Gemma/[Gemma_3]Gradio_LlamaCpp_Chatbot.ipynb (525 lines of code) (raw):

{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "fLtNzk95zuks" }, "source": [ "##### Copyright 2025 Google LLC." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "fnRDEtTRz1Ah" }, "outputs": [], "source": [ "# @title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "IJK05W_Rz9I3" }, "source": [ "# Build Chatbot with Gradio & Llama.cpp\n", "\n", "Author: Sitam Meur\n", "\n", "* GitHub: [github.com/sitamgithub-MSIT](https://github.com/sitamgithub-MSIT/)\n", "* X: [@sitammeur](https://x.com/sitammeur)\n", "\n", "Description: Google recently released Gemma 3 QAT—the [Quantization Aware Trained (QAT) Gemma 3](https://huggingface.co/collections/google/gemma-3-qat-67ee61ccacbf2be4195c265b) checkpoints. These models maintain similar quality to half precision while using three times less memory. 
This notebook demonstrates how to create a user-friendly chat interface for the [gemma-3-1b-it-qat](https://huggingface.co/google/gemma-3-1b-it-qat-q4_0-gguf) text model using Llama.cpp (for inference) and Gradio (for the user interface).\n", "\n", "<table align=\"left\">\n", " <td>\n", " <a target=\"_blank\" href=\"https://colab.research.google.com/github/google-gemini/gemma-cookbook/blob/main/Gemma/[Gemma_3]Gradio_LlamaCpp_Chatbot.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n", " </td>\n", "</table>" ] }, { "cell_type": "markdown", "metadata": { "id": "JEogAuER0MPC" }, "source": [ "## Setup\n", "\n", "### Select the Colab runtime\n", "To complete this tutorial, you'll need a Colab runtime with sufficient resources to run the Gemma 3 QAT model. A CPU runtime is enough for the 1B model used here:\n", "\n", "1. In the upper-right of the Colab window, select **▾ (Additional connection options)**.\n", "2. Select **Change runtime type**.\n", "3. Under **Hardware accelerator**, select **CPU**.\n", "\n", "### Gemma Setup\n", "\n", "**Before we dive into the tutorial, let's get you set up:**\n", "\n", "1. **Hugging Face Account:** If you don't already have one, you can create a free Hugging Face account by clicking [here](https://huggingface.co/join).\n", "2. **Model Access:** Head over to the [model page](https://huggingface.co/google/gemma-3-1b-it-qat-q4_0-gguf) and accept the usage conditions.\n", "3. **Hugging Face Token:** Generate a Hugging Face access token (preferably with `write` permission) by clicking [here](https://huggingface.co/settings/tokens). 
You'll need this token later in the tutorial.\n", "\n", "**Once you've completed these steps, you're ready to move on to the next section where we'll set up environment variables in your Colab environment.**" ] }, { "cell_type": "markdown", "metadata": { "id": "D3DIfibe1I4W" }, "source": [ "### Configure your HF token\n", "\n", "Add your Hugging Face token to the Colab Secrets manager to securely store it.\n", "\n", "1. Open your Google Colab notebook and click on the 🔑 Secrets tab in the left panel. <img src=\"https://storage.googleapis.com/generativeai-downloads/images/secrets.jpg\" alt=\"The Secrets tab is found on the left panel.\" width=50%>\n", "2. Create a new secret with the name `HF_TOKEN`.\n", "3. Copy/paste your token key into the Value input box of `HF_TOKEN`.\n", "4. Toggle the button on the left to allow notebook access to the secret." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Ox59ZaemK3j1" }, "outputs": [], "source": [ "import os\n", "from google.colab import userdata\n", "# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env\n", "# vars as appropriate for your system.\n", "os.environ[\"HF_TOKEN\"] = userdata.get(\"HF_TOKEN\")" ] }, { "cell_type": "markdown", "metadata": { "id": "ircX6fQA05Wr" }, "source": [ "### Install dependencies\n", "Run the cell below to install all the required dependencies." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "hFoZpGiQEiBW" }, "outputs": [], "source": [ "!pip install -q huggingface_hub scikit-build-core llama-cpp-python llama-cpp-agent gradio" ] }, { "cell_type": "markdown", "metadata": { "id": "gko1egPaLEyy" }, "source": [ "### Log into Hugging Face Hub" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "GiBP1bfhLFTz" }, "outputs": [], "source": [ "from huggingface_hub import login\n", "\n", "login(os.environ[\"HF_TOKEN\"])" ] }, { "cell_type": "markdown", "metadata": { "id": "QRO7WkLaLeW6" }, "source": [ "## Instantiate the Gemma 3 QAT model\n", "\n", "We’ll use the Hugging Face Hub to download the 4-bit (Q4_0) quantized Gemma 3 1B instruction-tuned text-only model. The new quantized Gemma 3 models, created using Quantization-Aware Training (QAT), are more accessible because they use less memory (about 3x less than bfloat16) without significant accuracy loss. QAT simulates low-precision operations during training, so the model stays accurate once it is quantized into a smaller, faster form.\n", "\n", "Let's get started by downloading the model from Hugging Face Hub." ] }, { "cell_type": "markdown", "metadata": { "id": "Bi5cfGCP2TR1" }, "source": [ "### Loading the model from HF Hub" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "BGNdy0NsqAo6" }, "outputs": [], "source": [ "# Download gguf model files\n", "from huggingface_hub import hf_hub_download\n", "\n", "# Create the local models directory if it does not already exist\n", "os.makedirs(\"./models\", exist_ok=True)\n", "\n", "hf_hub_download(\n", "    repo_id=\"google/gemma-3-1b-it-qat-q4_0-gguf\",\n", "    filename=\"gemma-3-1b-it-q4_0.gguf\",\n", "    local_dir=\"./models\",\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "fYnY8mpQ2R9w" }, "source": [ "### Prompt Formatting\n", "\n", "Gemma 3 models require a specific prompt format to distinguish the roles of the participants in a conversation. 
Gemma 3 has no separate system role, so the system prompt is placed at the start of the first user turn. The prompt format is as follows:\n", "\n", "```\n", "<bos><start_of_turn>user\n", "{system_prompt}\n", "\n", "{prompt}<end_of_turn>\n", "<start_of_turn>model\n", "```" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "IOPoBYD3IjFV" }, "outputs": [], "source": [ "# Define the prompt markers for Gemma 3\n", "from llama_cpp_agent.chat_history.messages import Roles\n", "from llama_cpp_agent.messages_formatter import MessagesFormatter, PromptMarkers\n", "\n", "gemma_3_prompt_markers = {\n", "    Roles.system: PromptMarkers(\"\", \"\\n\"),  # System prompt is folded into the first user message\n", "    Roles.user: PromptMarkers(\"<start_of_turn>user\\n\", \"<end_of_turn>\\n\"),\n", "    Roles.assistant: PromptMarkers(\"<start_of_turn>model\\n\", \"<end_of_turn>\\n\"),\n", "    Roles.tool: PromptMarkers(\"\", \"\"),  # Only needed if tool support is added\n", "}\n", "\n", "# Create the formatter\n", "gemma_3_formatter = MessagesFormatter(\n", "    pre_prompt=\"\",  # No pre-prompt\n", "    prompt_markers=gemma_3_prompt_markers,\n", "    include_sys_prompt_in_first_user_message=True,  # Include system prompt in first user message\n", "    default_stop_sequences=[\"<end_of_turn>\", \"<start_of_turn>\"],\n", "    strip_prompt=False,  # Don't strip whitespace from the prompt\n", "    bos_token=\"<bos>\",  # Beginning of sequence token for Gemma 3\n", "    eos_token=\"<eos>\",  # End of sequence token for Gemma 3\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "V92UpXaE2Wff" }, "source": [ "## Chat with Gemma 3\n", "\n", "The `respond` function below loads the model on first use and generates responses. With streaming enabled, it yields the reply as it is generated rather than waiting for the complete response." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "dB_YfAdX5JTL" }, "outputs": [], "source": [ "from typing import List, Tuple\n", "from llama_cpp import Llama\n", "from llama_cpp_agent import LlamaCppAgent\n", "from llama_cpp_agent.providers import LlamaCppPythonProvider\n", "from llama_cpp_agent.chat_history import BasicChatHistory\n", "from llama_cpp_agent.chat_history.messages import Roles\n", "\n", "# Cache the loaded model across calls so it is only loaded once\n", "llm = None\n", "llm_model = None\n", "\n", "def respond(\n", "    message: str,\n", "    history: List[Tuple[str, str]],\n", "    model: str = \"gemma-3-1b-it-q4_0.gguf\",\n", "    system_message: str = \"You are a helpful assistant.\",\n", "    max_tokens: int = 1024,\n", "    temperature: float = 0.7,\n", "    top_p: float = 0.95,\n", "    top_k: int = 40,\n", "    repeat_penalty: float = 1.1,\n", "):\n", "    \"\"\"\n", "    Respond to a message using the Gemma3 model via Llama.cpp.\n", "\n", "    Args:\n", "        - message (str): The message to respond to.\n", "        - history (List[Tuple[str, str]]): The chat history.\n", "        - model (str): The model to use.\n", "        - system_message (str): The system message to use.\n", "        - max_tokens (int): The maximum number of tokens to generate.\n", "        - temperature (float): The temperature of the model.\n", "        - top_p (float): The top-p of the model.\n", "        - top_k (int): The top-k of the model.\n", "        - repeat_penalty (float): The repetition penalty of the model.\n", "\n", "    Yields:\n", "        str: The response generated so far, growing as tokens stream in.\n", "    \"\"\"\n", "    try:\n", "        # Load the global variables\n", "        global llm\n", "        global llm_model\n", "\n", "        # Ensure model is not None\n", "        if model is None:\n", "            model = \"gemma-3-1b-it-q4_0.gguf\"\n", "\n", "        # Load the model\n", "        if llm is None or llm_model != model:\n", "            # Check if model file exists\n", "            model_path = f\"models/{model}\"\n", "            if not os.path.exists(model_path):\n", "                yield f\"Error: Model file not found at {model_path}. Please check your model path.\"\n", "                return\n", "\n", "            llm = Llama(\n", "                model_path=f\"models/{model}\",\n", "                flash_attn=False,\n", "                n_gpu_layers=0,\n", "                n_batch=8,\n", "                n_ctx=2048,\n", "                n_threads=8,\n", "                n_threads_batch=8,\n", "            )\n", "            llm_model = model\n", "        provider = LlamaCppPythonProvider(llm)\n", "\n", "        # Create the agent\n", "        agent = LlamaCppAgent(\n", "            provider,\n", "            system_prompt=f\"{system_message}\",\n", "            custom_messages_formatter=gemma_3_formatter,\n", "            debug_output=True,\n", "        )\n", "\n", "        # Set the settings like temperature, top-k, top-p, max tokens, etc.\n", "        settings = provider.get_provider_default_settings()\n", "        settings.temperature = temperature\n", "        settings.top_k = top_k\n", "        settings.top_p = top_p\n", "        settings.max_tokens = max_tokens\n", "        settings.repeat_penalty = repeat_penalty\n", "        settings.stream = True\n", "\n", "        messages = BasicChatHistory()\n", "\n", "        # Add the chat history\n", "        for msn in history:\n", "            user = {\"role\": Roles.user, \"content\": msn[0]}\n", "            assistant = {\"role\": Roles.assistant, \"content\": msn[1]}\n", "            messages.add_message(user)\n", "            messages.add_message(assistant)\n", "\n", "        # Get the response stream\n", "        stream = agent.get_chat_response(\n", "            message,\n", "            llm_sampling_settings=settings,\n", "            chat_history=messages,\n", "            returns_streaming_generator=True,\n", "            print_output=False,\n", "        )\n", "\n", "        # Accumulate streamed chunks and yield the growing response\n", "        outputs = \"\"\n", "        for output in stream:\n", "            outputs += output\n", "            yield outputs\n", "\n", "    # Handle exceptions that may occur during the process\n", "    except Exception as e:\n", "        raise Exception(f\"An error occurred: {str(e)}\") from e" ] }, { "cell_type": "markdown", "metadata": { "id": "F1PbdoSr2Y-u" }, "source": [ "## Gradio UI\n", "\n", "The interface below provides customizable buttons, editable messages, and an expandable panel where advanced users can tune generation parameters such as temperature and top-p." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "2vu56cxP548d" }, "outputs": [], "source": [ "# Create a chat interface\n", "import gradio as gr\n", "\n", "demo = gr.ChatInterface(\n", "    respond,\n", "    examples=[[\"What is the capital of France?\"], [\"Tell me something about artificial intelligence.\"], [\"What is gravity?\"]],\n", "    additional_inputs_accordion=gr.Accordion(\n", "        label=\"⚙️ Parameters\", open=False, render=False\n", "    ),\n", "    additional_inputs=[\n", "        gr.Dropdown(\n", "            choices=[\n", "                \"gemma-3-1b-it-q4_0.gguf\",\n", "            ],\n", "            value=\"gemma-3-1b-it-q4_0.gguf\",\n", "            label=\"Model\",\n", "            info=\"Select the AI model to use for chat\",\n", "        ),\n", "        gr.Textbox(\n", "            value=\"You are a helpful assistant.\",\n", "            label=\"System Prompt\",\n", "            info=\"Define the AI assistant's personality and behavior\",\n", "            lines=2,\n", "        ),\n", "        gr.Slider(\n", "            minimum=512,\n", "            maximum=2048,\n", "            value=1024,\n", "            step=1,\n", "            label=\"Max Tokens\",\n", "            info=\"Maximum length of response (higher = longer replies)\",\n", "        ),\n", "        gr.Slider(\n", "            minimum=0.1,\n", "            maximum=2.0,\n", "            value=0.7,\n", "            step=0.1,\n", "            label=\"Temperature\",\n", "            info=\"Creativity level (higher = more creative, lower = more focused)\",\n", "        ),\n", "        gr.Slider(\n", "            minimum=0.1,\n", "            maximum=1.0,\n", "            value=0.95,\n", "            step=0.05,\n", "            label=\"Top-p\",\n", "            info=\"Nucleus sampling threshold\",\n", "        ),\n", "        gr.Slider(\n", "            minimum=1,\n", "            maximum=100,\n", "            value=40,\n", "            step=1,\n", "            label=\"Top-k\",\n", "            info=\"Limit vocabulary choices to top K tokens\",\n", "        ),\n", "        gr.Slider(\n", "            minimum=1.0,\n", "            maximum=2.0,\n", "            value=1.1,\n", "            step=0.1,\n", "            label=\"Repetition Penalty\",\n", "            info=\"Penalize repeated words (higher = less repetition)\",\n", "        ),\n", "    ],\n", "    submit_btn=\"Send\",\n", "    stop_btn=\"Stop\",\n", "    chatbot=gr.Chatbot(scale=1, show_copy_button=True, resizable=True),\n", "    flagging_mode=\"never\",\n", "    editable=True,\n", "    cache_examples=False,\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "wKGE9W0ooEO3" }, "source": [ "Finally, we’ll launch our chat interface with Gradio." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kmhrakxX58av" }, "outputs": [], "source": [ "# Launch the chat interface\n", "demo.launch(debug=True)" ] }, { "cell_type": "markdown", "metadata": { "id": "gxlqWI812bJL" }, "source": [ "## What's Next?\n", "\n", "That's it! Here are some ideas to explore further:\n", "\n", "- **Explore the Gemma family models:** Visit [Gemma Open Models](https://ai.google.dev/gemma) to learn about the latest updates to the Gemma family of models, including new capabilities and versions.\n", "\n", "- **Gradio Customization:** Explore the [Gradio documentation](https://www.gradio.app/docs) to learn about customizing your chat interface and adding new options and features.\n", "\n", "- **Share Your Gradio Dashboard:** Check out the [Sharing your Gradio app](https://www.gradio.app/guides/sharing-your-app) page to learn how to safely share your Gradio dashboard with others!" ] } ], "metadata": { "colab": { "name": "[Gemma_3]Gradio_LlamaCpp_Chatbot.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }